from itertools import chain
from typing import Sequence, Any, Optional

import numpy as np
from gym import spaces as spaces

from centralized_verification.envs.fast_grid_world import FastGridWorld


def add_posn(ab, cd):
    """Return the component-wise sum of two 2-element positions as a tuple.

    Both arguments must unpack into exactly two values; otherwise the
    unpacking raises ``ValueError``.
    """
    (x1, y1), (x2, y2) = ab, cd
    return x1 + x2, y1 + y2


def flatten(iterator_to_flatten):
    """Concatenate the items of each inner iterable into a single new list."""
    return [item for inner in iterator_to_flatten for item in inner]


class FastGridWorld2DObs(FastGridWorld):
    """
    A FastGridWorld variant whose per-agent observation is a stack of 2D maps.

    Only the observation representation is overridden here. Dynamics, actions
    and atomic propositions come from ``FastGridWorld``. Per the base
    environment's docs: a single AP for agent collisions; actions 0 = no-op,
    1 = up, 2 = right, 3 = down, 4 = left. Confirm these against
    ``FastGridWorld``.

    Each observation is an integer array of shape
    ``(1 + num_agents, max_x + 1, max_y + 1)`` containing only 0s and 1s:

    * channel 0: 1 at every cell in ``self.grid_posns`` (traversable cells),
      0 elsewhere;
    * channel ``i + 1``: 1 at agent ``i``'s cell, provided agent ``i`` is
      within ``other_agent_obs_radius`` of the observing agent in both x and
      y (i.e. Chebyshev distance). Otherwise the channel is all zeros.

    :param other_agent_obs_radius: maximum per-axis distance at which other
        agents are visible; ``None`` means every agent is always visible.
    """

    def __init__(self, *args, other_agent_obs_radius: Optional[int] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.obs_radius = other_agent_obs_radius
        self.max_x = max(x for (x, y) in self.grid_posns)
        self.max_y = max(y for (x, y) in self.grid_posns)

        self.obs_space = (1 + self.num_agents, self.max_x + 1, self.max_y + 1)

        # Channel 0: traversable cells; channels 1..n are filled per observation.
        # ``int`` replaces the removed ``np.int`` alias (same dtype).
        self.base_map = np.zeros(self.obs_space, dtype=int)
        for (x, y) in self.grid_posns:
            self.base_map[0, x, y] = 1

    def agent_obs_spaces(self) -> Sequence[spaces.Space]:
        """Return one 0/1 integer Box of shape ``self.obs_space`` per agent."""
        return [spaces.Box(0, 1, shape=self.obs_space, dtype=int)] * self.num_agents

    def project_obs(self, state) -> Sequence[Any]:
        """Build each agent's stacked-map observation from the joint ``state``.

        :param state: sequence of per-agent indices into ``self.grid_posns``.
        :return: list with one array of shape ``self.obs_space`` per agent, in
            the same order as ``state``.
        """
        # Look up every agent's (x, y) once instead of inside the O(n^2) loop.
        positions = [self.grid_posns[space_num] for space_num in state]
        radius = self.obs_radius

        obs_list = []
        for this_x, this_y in positions:
            obs = self.base_map.copy()
            for that_agent_num, (that_x, that_y) in enumerate(positions):
                # Agents outside the square (Chebyshev) radius stay invisible.
                if radius is not None and (abs(this_x - that_x) > radius
                                           or abs(this_y - that_y) > radius):
                    continue
                obs[that_agent_num + 1, that_x, that_y] = 1
            obs_list.append(obs)

        return obs_list
